Skip to content

Conversation

@lhutton1
Copy link
Contributor

Previously the following IR:

tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>

Would result in a crash with the assertion:

expected dense element bit width 64 to match data size 32 for type i64

This commit ensures that zero is constructed with the correct bitwidth while folding, therefore fixing the crash.

Previously the following IR:
```
tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>
```
Would result in a crash with the assertion:
```
expected dense element bit width 64 to match data size 32 for type i64
```

This commit ensures that zero is constructed with the correct
bitwidth while folding, therefore fixing the crash.

Change-Id: I4531d9f36fb2e682d46075229ad15f12f53d7cf5
@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Previously the following IR:

tosa.argmax %arg0 {axis = 0 : i32} : (tensor&lt;1xi8&gt;) -&gt; tensor&lt;i64&gt;

Would result in a crash with the assertion:

expected dense element bit width 64 to match data size 32 for type i64

This commit ensures that zero is constructed with the correct bitwidth while folding, therefore fixing the crash.


Full diff: https://github.com/llvm/llvm-project/pull/163583.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+9)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf80165fc640..99b7cda49094e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
       !outputTy.hasStaticShape())
     return {};
 
-  if (inputTy.getDimSize(getAxis()) == 1)
-    return DenseElementsAttr::get(outputTy, 0);
+  const Type outputElementTy = getElementTypeOrSelf(outputTy);
+  if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
+    const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
+    const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
+    return DenseElementsAttr::get(outputTy, zero);
+  }
 
   return {};
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e8525a5d2ed62..7574afa215e78 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
 
 // -----
 
+// CHECK-LABEL: @test_argmax_fold_i64_index
+func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> {
+  // CHECK: "tosa.const"() <{values = dense<0> : tensor<i64>}> : () -> tensor<i64>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>
+  return %0 : tensor<i64>
+}
+
+// -----
+
 // CHECK-LABEL: @pad_wh_avg_pool2d_fold
 func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
   // CHECK-NOT: tosa.pad

@lhutton1
Copy link
Contributor Author

cc @udaya-ranga @Tai78641 @psunn

@Tai78641
Copy link
Contributor

LGTM

@GeorgeARM GeorgeARM merged commit 949148c into llvm:main Oct 22, 2025
13 checks passed
@lhutton1 lhutton1 deleted the fix-argmax-folder branch October 23, 2025 16:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants